#' @include utilities.R
NULL
#' Ggplots of Fitted Flexible Survival Models
#' @description Create ggplot2-based graphs for flexible survival models.
#'   Kaplan-Meier curves are computed from the data and drawn first; the
#'   curves estimated by the \code{flexsurvreg} model are then overlaid.
#' @inheritParams ggsurvplot_arguments
#' @param fit an object of class \code{flexsurvreg}.
#' @param data the data used to fit survival curves.
#' @param fun the type of survival curves. Allowed values include "survival"
#'   (default) and "cumhaz" (for cumulative hazard).
#' @param summary.flexsurv (optional) the summary of the \code{flexsurvreg}
#'   object as generated by the function \code{summary()}.
#' @param size line size for the flexible survival estimates.
#' @param conf.int,conf.int.flex logical. If TRUE, add confidence bands for
#'   flexible survival estimates. When \code{conf.int} is not specified and
#'   only a single Kaplan-Meier curve is drawn, it is set to TRUE.
#' @param conf.int.km same as \code{conf.int.flex} but for the Kaplan-Meier
#'   estimates.
#' @param ... additional arguments passed to the function \code{\link{ggsurvplot}()}.
#' @author Alboukadel Kassambara, \email{alboukadel.kassambara@@gmail.com}
#' @return a ggsurvplot
#' @examples
#' \donttest{
#' if(require("flexsurv")) {
#' fit <- flexsurvreg(Surv(rectime, censrec) ~ group,
#'                    dist = "gengamma", data = bc)
#' ggflexsurvplot(fit)
#' }
#' }
#'
#' @name ggflexsurvplot
#' @rdname ggflexsurvplot
#' @export
ggflexsurvplot <- function(fit, data = NULL,
                           fun = c("survival", "cumhaz"),
                           summary.flexsurv = NULL,
                           size = 1, conf.int = FALSE,
                           conf.int.flex = conf.int, conf.int.km = FALSE,
                           legend.labs = NULL,
                           ...
                           )
{
  if (!requireNamespace("flexsurv", quietly = TRUE)) {
    stop("flexsurv package needed for this function to work. Please install it.")
  }

  if (!inherits(fit, "flexsurvreg")) {
    # Join multiple classes with a separator instead of gluing them together
    stop("Can't handle an object of class ",
         paste(class(fit), collapse = ", "))
  }
  fun <- match.arg(fun)

  data <- .get_data(fit, data = data, complain = FALSE)

  # Fitted curves as a single data frame carrying a `strata` column
  summ <- .summary_flexsurv(fit, type = fun,
                            summary.flexsurv = summary.flexsurv)
  .strata <- summ$strata
  n.strata <- length(.levels(.strata))

  fit.ext <- .extract.survfit(fit)
  surv.obj <- fit.ext$surv
  .formula <- fit.ext$formula

  # Stratify the Kaplan-Meier fit only when every covariate is discrete;
  # otherwise fall back to a single overall KM curve.
  isfac <- .is_all_covariate_factor(fit)
  if (!all(isfac)) {
    .formula <- .build_formula(surv.obj, "1")
    n.strata <- 1
  }

  # NOTE: the default `conf.int.flex = conf.int` is a lazily evaluated
  # promise that is first forced further below, so updating `conf.int`
  # here also changes the default of `conf.int.flex`. Do not use
  # `conf.int.flex` before this point.
  if (n.strata == 1 && missing(conf.int)) {
    conf.int <- TRUE
  }

  # Fit KM survival curves
  #::::::::::::::::::::::::::::::::::::::
  x <- do.call(survival::survfit, list(formula = .formula, data = data))
  # Plain survival curves are requested with fun = NULL; only other
  # transformations (e.g. "cumhaz") are forwarded by name.
  fun <- if (fun == "survival") NULL else fun
  ggsurv <- ggsurvplot_core(x, data = data, size = 0.5,
                            fun = fun, conf.int = conf.int.km,
                            legend.labs = legend.labs, ...)

  # Overlay the fitted models
  #::::::::::::::::::::::::::::::::::::::

  # Relabel the fitted strata with the user-supplied legend labels
  if (!is.null(legend.labs)) {
    if (n.strata != length(legend.labs)) {
      stop("The length of legend.labs should be ", n.strata )
    }
    summ$strata <- factor(summ$strata,
                          levels = .levels(.strata),
                          labels = legend.labs)
  }

  # Bind the aes() column names locally to avoid "no visible binding"
  # notes from R CMD check.
  time <- est <- strata <- lcl <- ucl <- NULL
  ggsurv$plot <- ggsurv$plot +
    geom_line(aes(time, est, color = strata),
              data = summ, size = size)

  if (conf.int.flex) {
    ggsurv$plot <- ggsurv$plot +
      geom_line(aes(time, lcl, color = strata), data = summ,
                size = 0.5, linetype = "dashed") +
      geom_line(aes(time, ucl, color = strata), data = summ,
                size = 0.5, linetype = "dashed")
  }

  ggsurv
}


# Combine the summary of a flexsurvreg fit into one data frame.
#
# @param fit a flexsurvreg object; only used when summary.flexsurv is NULL.
# @param type passed to summary() as `type` (e.g. "survival" or "cumhaz").
# @param summary.flexsurv optional precomputed summary (a list of data
#   frames, as returned by summary(fit)); computed from `fit` when NULL.
# @return a data frame made of the summary table(s) plus a `strata` column:
#   the character value "All" when the summary holds a single table,
#   otherwise a factor whose levels follow the order of names(summary).
.summary_flexsurv <- function(fit, type = "survival", summary.flexsurv = NULL)
{
  summ <- summary.flexsurv
  if (is.null(summ)) {
    summ <- summary(fit, type = type)
  }

  # Single table: reuse the summary obtained above (instead of calling
  # summary(fit) again) so that `type` and a user-supplied summary are
  # honoured.
  if (length(summ) == 1) {
    out <- summ[[1]]
    out$strata <- rep("All", nrow(out))
    return(out)
  }

  .strata <- names(summ)
  summ <- purrr::map2(
    .strata, summ,
    function(.s, .summ) dplyr::mutate(.summ, strata = .s)
  )
  summ <- dplyr::bind_rows(summ)
  summ$strata <- factor(summ$strata, levels = .strata)
  summ
}


# Check whether each covariate of a flexsurvreg model is discrete
# (factor or character vector).
#
# @param fit a flexsurvreg object.
# @return a named logical vector with one element per covariate listed in
#   the model frame's "covnames.orig" attribute (empty when there is none).
.is_all_covariate_factor <- function(fit){
  mf <- stats::model.frame(fit)
  # NOTE(review): "covnames.orig" presumably holds the covariate names as
  # written in the model formula (attribute set by flexsurv) - confirm.
  covariates <- mf[, attr(mf, "covnames.orig"), drop = FALSE]
  # vapply keeps the result a logical vector even with zero covariates
  vapply(covariates, is_factor_or_character, logical(1))
}

# Is `x` a discrete variable, i.e. a factor or a character vector?
#
# @param x a vector (typically a model-frame column).
# @return a single logical.
is_factor_or_character <- function(x){
  is.factor(x) || is.character(x)
}
